# import packages
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
import sys

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited', wvarb = 'unitwt'):
    """Linear (regression) estimator of the disparity in `outcome`.

    Fits the weighted least-squares model ``outcome ~ pbvarb`` on `dataset`,
    using ``dataset[wvarb]`` as weights and HC1 robust standard errors.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain the columns named by `pbvarb`, `outcome` and `wvarb`.
    pbvarb : str
        Name of the regressor column.
    outcome : str
        Name of the dependent-variable column.
    wvarb : str
        Name of the column holding the regression weights.

    Returns
    -------
    tuple
        ``(coef, se, black_aud_rate, non_black_aud_rate)`` where `coef` and
        `se` are the slope on `pbvarb` and its standard error,
        `non_black_aud_rate` is the fitted intercept, and `black_aud_rate`
        is intercept + slope (the fitted value at ``pbvarb == 1``).
    """
    formula = f'{outcome} ~ {pbvarb}'
    model = smf.wls(formula, dataset, weights=dataset[wvarb]).fit(cov_type='HC1')

    coef = model.params[pbvarb]
    se = model.bse[pbvarb]

    # model.params is a label-indexed Series ('Intercept', pbvarb). Use .iloc
    # for positional access: integer keys passed to [] on such a Series are
    # deprecated in pandas and will be treated as labels in future releases.
    intercept = model.params.iloc[0]

    black_aud_rate = intercept + coef
    non_black_aud_rate = intercept
    return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited', wvarb = 'unitwt'):
    """Probabilistic (weighting) estimator of the disparity in `outcome`.

    Each row contributes to the "black" rate in proportion to
    ``dataset[pbvarb]`` and to the "non-black" rate in proportion to
    ``1 - dataset[pbvarb]``; both rates are additionally weighted by
    ``dataset[wvarb]``.

    Returns
    -------
    tuple
        ``(est, black_aud_rate, non_black_aud_rate)`` with
        ``est = black_aud_rate - non_black_aud_rate``.
    """
    prob = dataset[pbvarb]
    result = dataset[outcome]
    weight = dataset[wvarb]

    def weighted_rate(share):
        # Weighted mean of the outcome, with each row's weight scaled by `share`.
        numerator = (share * result * weight).sum()
        denominator = (share * weight).sum()
        return numerator / denominator

    black_aud_rate = weighted_rate(prob)
    non_black_aud_rate = weighted_rate(1 - prob)
    return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

# Columns kept from both input files.
keep_cols = ['taxpayer_id_new', 'aud_ind', 'base_weight', 'predicted_prob_black', 'predicted_prob_nonblack']

# One taxpayer in each file was only partially audited. Each is split into
# two pseudo-taxpayers: the rows with aud_ind == 1 get id + 1, the remaining
# rows keep the original id.
oracle_partial_id = 2789524184
prediction_partial_id = 2605542772

# read in oracle data
df1 = pd.read_csv('/REDACTED/data/modeled_refactor_temp/unres_oracle_selected_taxpayer_ids_alt_outcome.csv')
# .copy() so the column assignment below acts on an independent frame
df1 = df1[keep_cols].copy()

# for the person that was partially audited, adjust the taxpayer_ids so that we consider this 2 different people
df1['taxpayer_id_new'] = np.where(
    (df1.taxpayer_id_new == oracle_partial_id) & (df1.aud_ind == 1),
    df1['taxpayer_id_new'] + 1,
    df1['taxpayer_id_new'],
)

# save this data to add to prediction data (copied: aud_ind is overwritten below)
rows_for_df2 = df1[df1.taxpayer_id_new.isin([oracle_partial_id, oracle_partial_id + 1])].copy()

# read in prediction data
df2 = pd.read_csv('/REDACTED/data/modeled_refactor_temp/unres_reg_selected_taxpayer_ids_alt_outcome.csv')
df2 = df2[keep_cols].copy()

# for the person that was partially audited, adjust the taxpayer_ids so that we consider this 2 different people
df2['taxpayer_id_new'] = np.where(
    (df2.taxpayer_id_new == prediction_partial_id) & (df2.aud_ind == 1),
    df2['taxpayer_id_new'] + 1,
    df2['taxpayer_id_new'],
)

# save this data to add to oracle data (copied: aud_ind is overwritten below)
rows_for_df1 = df2[df2.taxpayer_id_new.isin([prediction_partial_id, prediction_partial_id + 1])].copy()

# now add partially audited rows to the complement dataset
# The original author inspected df1's existing row for prediction_partial_id
# and noted that its aud_ind is zero, so the transplanted rows are given
# aud_ind = 0 before they replace that row.
rows_for_df1['aud_ind'] = 0

df1 = df1[df1.taxpayer_id_new != prediction_partial_id]
df1 = pd.concat([df1, rows_for_df1])

# Same for df2: its existing row for oracle_partial_id was noted to have aud_ind zero.
rows_for_df2['aud_ind'] = 0

df2 = df2[df2.taxpayer_id_new != oracle_partial_id]
df2 = pd.concat([df2, rows_for_df2])

# check that everything looks correct
# (df1_small / df2_small are kept for manual inspection of the four split ids)
weird_taxpayer_ids = [2789524184, 2789524185, 2605542772, 2605542773]
df1_small = df1[df1.taxpayer_id_new.isin(weird_taxpayer_ids)]
df2_small = df2[df2.taxpayer_id_new.isin(weird_taxpayer_ids)] # this looks good

##########################
# Prob version, Table A.16
##########################

# write out results
# The report is printed directly into the file inside a `with` block, so the
# handle is always closed and sys.stdout is never left redirected if one of
# the estimates raises.
with open("/REDACTED/excess_disp_decomp_prob.txt", "w") as prob_out:
    ### Oracle selection rates:
    # chenEstimate returns (disparity, black rate, non-black rate); one call
    # gives all three values.
    prob_audit_disparity, prob_black_audit_rate, prob_nonblack_audit_rate = chenEstimate(
        df1, 'predicted_prob_black', 'aud_ind', 'base_weight')

    print("Oracle Selection Rate", file=prob_out)
    print("Black: " + str(prob_black_audit_rate*100), file=prob_out)
    print("Non-Black: " + str(prob_nonblack_audit_rate*100), file=prob_out)

    ### FPR:
    # subset to compliant TPs in the prediction data
    # (ids whose oracle aud_ind is 0; the rate of aud_ind among them in df2 is the FPR)
    comp_pop = df1[df1.aud_ind == 0]
    comp_taxpayer_ids = comp_pop.taxpayer_id_new.unique().tolist()

    df2_comp = df2[df2.taxpayer_id_new.isin(comp_taxpayer_ids)]

    _, black_fpr, nonblack_fpr = chenEstimate(df2_comp, 'predicted_prob_black', 'aud_ind', 'base_weight')

    print("\nFalse Positive Rate", file=prob_out)
    print("Black: " + str(black_fpr*100), file=prob_out)
    print("Non-Black: " + str(nonblack_fpr*100), file=prob_out)

    ### FNR:
    # subset to non-compliant TPs in the prediction data
    # (ids whose oracle aud_ind is 1; FNR is one minus the rate of aud_ind among them in df2)
    noncomp_pop = df1[df1.aud_ind == 1]
    noncomp_taxpayer_ids = noncomp_pop.taxpayer_id_new.unique().tolist()

    df2_noncomp = df2[df2.taxpayer_id_new.isin(noncomp_taxpayer_ids)]

    _, black_tpr, nonblack_tpr = chenEstimate(df2_noncomp, 'predicted_prob_black', 'aud_ind', 'base_weight')
    black_fnr = 1 - black_tpr
    nonblack_fnr = 1 - nonblack_tpr

    print("\nFalse Negative Rate", file=prob_out)
    print("Black: " + str(black_fnr*100), file=prob_out)
    print("Non-Black: " + str(nonblack_fnr*100), file=prob_out)

    # combine values to get the three terms of the decomposition
    first_term = ((1-prob_nonblack_audit_rate)*(black_fpr - nonblack_fpr))*100
    second_term  = (prob_nonblack_audit_rate * (nonblack_fnr - black_fnr))*100
    third_term = -(prob_audit_disparity * (black_fnr + black_fpr))*100

    print("\nExcess Disparity Contribution from ... ", file=prob_out)
    print("Scaled Difference in False Positive Rates: " + str(first_term), file=prob_out)
    print("Scaled Difference in False Negative Rates: " + str(second_term), file=prob_out)
    print("Attenuation of Oracle Disparity: " + str(third_term), file=prob_out)

    total_excess_disparity = first_term + second_term + third_term

    print("\nTotal Excess Disparity: " + str(total_excess_disparity), file=prob_out)

    # The oracle disparity is the same estimate on df1 already computed above.
    oracle_disp = prob_audit_disparity
    prediction_disp = chenEstimate(df2, 'predicted_prob_black', 'aud_ind', 'base_weight')[0]

    print("\nOracle Disparity: " + str(oracle_disp*100), file=prob_out)
    print("\nPrediction Model Disparity: " + str(prediction_disp*100), file=prob_out)

##########################
# Linear version, Table A.17
##########################

# write out results
# The report is printed directly into the file inside a `with` block, so the
# handle is always closed and sys.stdout is never left redirected if one of
# the regressions raises.
with open("/REDACTED/excess_disp_decomp_lin.txt", "w") as lin_out:
    ### Oracle selection rates:
    # regEstimate returns (coef, se, black rate, non-black rate); fit the
    # model once instead of refitting it for every value that is needed.
    lin_audit_disparity, _, lin_black_audit_rate, lin_nonblack_audit_rate = regEstimate(
        df1, 'predicted_prob_black', 'aud_ind', 'base_weight')

    print("Oracle Selection Rate", file=lin_out)
    print("Black: " + str(lin_black_audit_rate*100), file=lin_out)
    print("Non-Black: " + str(lin_nonblack_audit_rate*100), file=lin_out)

    ### FPR:
    # subset to compliant TPs in the prediction data
    # (ids whose oracle aud_ind is 0; the fitted rate of aud_ind among them in df2 is the FPR)
    comp_pop = df1[df1.aud_ind == 0]
    comp_taxpayer_ids = comp_pop.taxpayer_id_new.unique().tolist()

    df2_comp = df2[df2.taxpayer_id_new.isin(comp_taxpayer_ids)]

    _, _, black_fpr, nonblack_fpr = regEstimate(df2_comp, 'predicted_prob_black', 'aud_ind', 'base_weight')

    print("\nFalse Positive Rate", file=lin_out)
    print("Black: " + str(black_fpr*100), file=lin_out)
    print("Non-Black: " + str(nonblack_fpr*100), file=lin_out)

    ### FNR:
    # subset to non-compliant TPs in the prediction data
    # (ids whose oracle aud_ind is 1; FNR is one minus the fitted rate of aud_ind among them in df2)
    noncomp_pop = df1[df1.aud_ind == 1]
    noncomp_taxpayer_ids = noncomp_pop.taxpayer_id_new.unique().tolist()

    df2_noncomp = df2[df2.taxpayer_id_new.isin(noncomp_taxpayer_ids)]

    _, _, black_tpr, nonblack_tpr = regEstimate(df2_noncomp, 'predicted_prob_black', 'aud_ind', 'base_weight')
    black_fnr = 1 - black_tpr
    nonblack_fnr = 1 - nonblack_tpr

    print("\nFalse Negative Rate", file=lin_out)
    print("Black: " + str(black_fnr*100), file=lin_out)
    print("Non-Black: " + str(nonblack_fnr*100), file=lin_out)

    # combine values to get the three terms of the decomposition
    first_term = ((1-lin_nonblack_audit_rate)*(black_fpr - nonblack_fpr))*100
    second_term  = (lin_nonblack_audit_rate * (nonblack_fnr - black_fnr))*100
    third_term = -(lin_audit_disparity * (black_fnr + black_fpr))*100

    print("\nExcess Disparity Contribution from ... ", file=lin_out)
    print("Scaled Difference in False Positive Rates: " + str(first_term), file=lin_out)
    print("Scaled Difference in False Negative Rates: " + str(second_term), file=lin_out)
    print("Attenuation of Oracle Disparity: " + str(third_term), file=lin_out)

    total_excess_disparity = first_term + second_term + third_term
    print("\nTotal Excess Disparity: " + str(total_excess_disparity), file=lin_out)

    # The oracle disparity is the slope of the df1 regression already fitted above.
    oracle_disp = lin_audit_disparity
    prediction_disp = regEstimate(df2, 'predicted_prob_black', 'aud_ind', 'base_weight')[0]

    print("\nOracle Disparity: " + str(oracle_disp*100), file=lin_out)
    print("\nPrediction Model Disparity: " + str(prediction_disp*100), file=lin_out)